﻿#region -- 版 本 注 释 --
/****************************************************
* 文 件 名：
* Copyright(c) 王树羽
* CLR 版本: 4.5
* 创 建 人：王树羽
* 电子邮箱：674613047@qq.com
* 官方网站：https://www.cnblogs.com/shuyu
* 创建日期：2018-06-25 
* 文件描述：
******************************************************
* 修 改 人：
* 修改日期：
* 备注描述：
*******************************************************/
#endregion

using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Entity;
using System.Data.Entity.Infrastructure;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;

namespace CommonLib
{
    /// <summary>
    /// EntityFramework 开发帮助类
    /// </summary>
    public static class EntityFrameworkHelper
    {
        #region EntityFramework

        #region 添加

        /// <summary>
        /// Adds <paramref name="entity"/> to the context (with insert defaults applied) and saves immediately.
        /// </summary>
        /// <remarks>
        /// The default-value handling and the transition to <see cref="EntityState.Added"/> are performed by
        /// <c>ExTranAdd</c>, the non-saving variant, so both methods always apply exactly the same defaults.
        /// This method only adds the <c>SaveChanges</c> call.
        /// </remarks>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to add the entity to.</param>
        /// <param name="entity">The entity to insert.</param>
        /// <returns>The value returned by <c>DbContext.SaveChanges()</c>.</returns>
        public static int ExAdd<T>(this DbContext DB, T entity) where T : class
        {
            // ExAdd used to repeat ExTranAdd's default-value loop verbatim; delegating keeps them in sync.
            DB.ExTranAdd(entity);
            return DB.SaveChanges();
        }

        /// <summary>
        /// Fills in insert defaults on <paramref name="entity"/> and marks it as <see cref="EntityState.Added"/>
        /// without saving, so that several operations can be written together by a later <c>SaveChanges</c> call.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> that will track the entity.</param>
        /// <param name="entity">The entity to insert.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        public static void ExTranAdd<T>(this DbContext DB, T entity) where T : class
        {
            if (entity == null)
            {
                throw new ArgumentNullException("entity");
            }

            ApplyInsertDefaults(entity);
            DB.Entry<T>(entity).State = EntityState.Added;
        }

        /// <summary>
        /// Populates conventional columns of <typeparamref name="T"/> (matched by property name, ignoring case)
        /// that the caller left unset:
        /// <list type="bullet">
        /// <item><c>Guid</c> (Guid or Guid?): a new GUID when null or <see cref="Guid.Empty"/>.</item>
        /// <item><c>CreateTime</c> (DateTime or DateTime?): <see cref="DateTime.Now"/> when null or <c>default(DateTime)</c>.</item>
        /// <item><c>IsDelete</c> (any nullable value type): the underlying type's zero value when null.</item>
        /// <item><c>OrgLevels</c> / <c>CityLevels</c>: the value returned by <c>OrgLevels(entity)</c> /
        /// <c>CityLevels(entity)</c>, when that value is non-empty.</item>
        /// </list>
        /// Read-only properties and properties of any other type are left untouched.
        /// </summary>
        /// <param name="entity">Entity to populate; must not be null.</param>
        private static void ApplyInsertDefaults<T>(T entity) where T : class
        {
            foreach (PropertyInfo pi in typeof(T).GetProperties())
            {
                if (!pi.CanWrite)
                {
                    continue;
                }

                // ToUpperInvariant: culture-sensitive ToUpper() maps 'i' to a dotted capital I under tr-TR,
                // so names such as "Guid" or "IsDelete" would never match there.
                switch (pi.Name.ToUpperInvariant())
                {
                    case "GUID":
                        if (pi.PropertyType == typeof(Guid) || pi.PropertyType == typeof(Guid?))
                        {
                            object currentGuid = pi.GetValue(entity, null);
                            if (currentGuid == null || (Guid)currentGuid == Guid.Empty)
                            {
                                pi.SetValue(entity, Guid.NewGuid(), null);
                            }
                        }
                        break;

                    case "CREATETIME":
                        if (pi.PropertyType == typeof(DateTime) || pi.PropertyType == typeof(DateTime?))
                        {
                            object currentTime = pi.GetValue(entity, null);
                            // A non-nullable DateTime is never null; default(DateTime) is its "not set" value.
                            if (currentTime == null || (DateTime)currentTime == default(DateTime))
                            {
                                pi.SetValue(entity, DateTime.Now, null);
                            }
                        }
                        break;

                    case "ISDELETE":
                        {
                            // Only nullable value types can be null here; Activator yields the zero value
                            // (0 for int?, false for bool?), matching the IsDelete == 0 filter used by ExPage.
                            Type underlying = Nullable.GetUnderlyingType(pi.PropertyType);
                            if (underlying != null && pi.GetValue(entity, null) == null)
                            {
                                pi.SetValue(entity, Activator.CreateInstance(underlying), null);
                            }
                        }
                        break;

                    case "ORGLEVELS":
                        {
                            string orgLevels = OrgLevels(entity);
                            if (!string.IsNullOrEmpty(orgLevels))
                            {
                                pi.SetValue(entity, orgLevels, null);
                            }
                        }
                        break;

                    case "CITYLEVELS":
                        {
                            string cityLevels = CityLevels(entity);
                            if (!string.IsNullOrEmpty(cityLevels))
                            {
                                pi.SetValue(entity, cityLevels, null);
                            }
                        }
                        break;
                }
            }
        }

        #region 获取创建人组织机构层次

        /// <summary>
        /// Resolves the creator's organisation-hierarchy path ("OrgLevels") for a new entity.
        /// </summary>
        /// <returns>
        /// The entity's own <c>OrgLevels</c> value when it is already set; otherwise the <c>OrgLevels</c>
        /// column of the <c>LoginCache</c> row whose <c>UserInfoID</c> equals the entity's
        /// <c>CreateUserInfoID</c>; otherwise <see cref="string.Empty"/>.
        /// </returns>
        private static string OrgLevels<T>(T entity)
        {
            return LookupCreatorLevels(entity, "OrgLevels");
        }

        /// <summary>
        /// Resolves the "CityLevels" path for a new entity (presumably the creator's city/region
        /// hierarchy; TODO confirm). Same lookup rules as <c>OrgLevels</c>, using the <c>CityLevels</c>
        /// property and column.
        /// </summary>
        private static string CityLevels<T>(T entity)
        {
            return LookupCreatorLevels(entity, "CityLevels");
        }

        /// <summary>
        /// Shared implementation of <c>OrgLevels</c> and <c>CityLevels</c>.
        /// </summary>
        /// <param name="entity">Entity being inserted.</param>
        /// <param name="column">
        /// Name of both the entity property and the <c>LoginCache</c> column. It is spliced into SQL,
        /// so it must be a trusted constant (it only ever comes from the two wrappers above).
        /// </param>
        /// <returns>The resolved value, or <see cref="string.Empty"/> when it cannot be determined.</returns>
        private static string LookupCreatorLevels<T>(T entity, string column)
        {
            try
            {
                Type type = entity.GetType();

                PropertyInfo levelsProperty = type.GetProperty(column);
                if (levelsProperty == null)
                {
                    return string.Empty;
                }

                object levels = levelsProperty.GetValue(entity, null);
                if (levels != null)
                {
                    return levels.ToString();
                }

                PropertyInfo creatorProperty = type.GetProperty("CreateUserInfoID");
                object creatorId = creatorProperty == null ? null : creatorProperty.GetValue(entity, null);
                if (creatorId == null)
                {
                    return string.Empty;
                }

                // The db helper's API is not visible here, so the value cannot be bound as a parameter;
                // doubling single quotes at least stops a string id from terminating the SQL literal.
                // NOTE(review): switch to a parameterised query if IDBhelper supports one.
                string idLiteral = creatorId.ToString().Replace("'", "''");
                var db = new CommonLib.DbHelper.Factory().IDBhelper;
                string result = db.ExecuteScalar("SELECT " + column + " FROM LoginCache WHERE(UserInfoID = '" + idLiteral + "')");
                return result ?? string.Empty;
            }
            catch
            {
                // Deliberately best-effort: a failed lookup (missing row, DB error) must not block the
                // insert; the caller simply leaves the column unchanged when this returns empty.
                return string.Empty;
            }
        }
        #endregion

        #endregion 添加

        #region 批量添加

        /// <summary>
        /// Adds every entity in <paramref name="list"/> (each with insert defaults applied by <c>ExTranAdd</c>)
        /// and writes them with a single <c>SaveChanges</c> call.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to add the entities to.</param>
        /// <param name="list">Entities to insert.</param>
        /// <returns>The value returned by <c>DbContext.SaveChanges()</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="list"/> is null.</exception>
        public static int ExBatchAdd<T>(this DbContext DB, List<T> list) where T : class
        {
            if (list == null)
            {
                throw new ArgumentNullException("list");
            }

            foreach (T item in list)
            {
                DB.ExTranAdd(item);
            }
            return DB.SaveChanges();
        }

        /// <summary>
        /// Marks every entity in <paramref name="list"/> as added (each with insert defaults applied by
        /// <c>ExTranAdd</c>) without saving; a later <c>SaveChanges</c> call writes them.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> that will track the entities.</param>
        /// <param name="list">Entities to insert.</param>
        /// <exception cref="ArgumentNullException"><paramref name="list"/> is null.</exception>
        public static void ExTranBatchAdd<T>(this DbContext DB, List<T> list) where T : class
        {
            if (list == null)
            {
                throw new ArgumentNullException("list");
            }

            foreach (T item in list)
            {
                DB.ExTranAdd(item);
            }
        }

        #endregion

        #region 更新

        /// <summary>
        /// Writes <paramref name="entity"/> as a full update of its existing row and saves immediately.
        /// </summary>
        /// <remarks>
        /// <para>If the entity has a writable <c>UpdateTime</c> property (DateTime or DateTime?, name matched
        /// ignoring case) that is null or <c>default(DateTime)</c>, it is set to <see cref="DateTime.Now"/>.</para>
        /// <para>Any instance with the same key that the context already tracks is detached first, because
        /// <c>Attach</c> would otherwise throw for a duplicate key. The entity is then attached and marked
        /// <see cref="EntityState.Modified"/>, so all of its scalar properties are written.</para>
        /// </remarks>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to update through.</param>
        /// <param name="entity">The entity carrying the new values and the key of the row to update.</param>
        /// <returns>The value returned by <c>DbContext.SaveChanges()</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        public static int ExUpdate<T>(this DbContext DB, T entity) where T : class
        {
            if (entity == null)
            {
                throw new ArgumentNullException("entity");
            }

            // The previous check `pi.Name.ToUpper() == "UpdateTime"` compared an upper-cased name with a
            // mixed-case literal and could never succeed, so UpdateTime was never defaulted.
            foreach (PropertyInfo pi in typeof(T).GetProperties())
            {
                if (!pi.CanWrite
                    || !string.Equals(pi.Name, "UpdateTime", StringComparison.OrdinalIgnoreCase)
                    || (pi.PropertyType != typeof(DateTime) && pi.PropertyType != typeof(DateTime?)))
                {
                    continue;
                }

                object current = pi.GetValue(entity, null);
                if (current == null || (DateTime)current == default(DateTime))
                {
                    pi.SetValue(entity, DateTime.Now, null);
                }
            }

            var objContext = ((IObjectContextAdapter)DB).ObjectContext;
            var objSet = objContext.CreateObjectSet<T>();
            var entityKey = objContext.CreateEntityKey(objSet.EntitySet.Name, entity);

            // Look only at what the context already tracks. ObjectContext.TryGetObjectByKey, used before,
            // queries the database when the key is not tracked, which cost an extra round trip per update.
            System.Data.Entity.Core.Objects.ObjectStateEntry trackedEntry;
            if (objContext.ObjectStateManager.TryGetObjectStateEntry(entityKey, out trackedEntry)
                && trackedEntry.Entity != null)
            {
                objContext.Detach(trackedEntry.Entity);
            }

            DB.Set<T>().Attach(entity);
            DB.Entry<T>(entity).State = EntityState.Modified;
            return DB.SaveChanges();
        }




        /// <summary>
        /// Attaches <paramref name="entity"/> and marks it <see cref="EntityState.Modified"/> without saving,
        /// so the update is written by a later <c>SaveChanges</c> call together with other pending changes.
        /// </summary>
        /// <remarks>
        /// <para>If the entity has a writable <c>UpdateTime</c> property (DateTime or DateTime?, name matched
        /// ignoring case) that is null or <c>default(DateTime)</c>, it is set to <see cref="DateTime.Now"/>.</para>
        /// <para>No already-tracked instance is detached here: detaching could silently drop changes still
        /// pending in the same unit of work. As a result, <c>Attach</c> throws if the context already tracks a
        /// different instance with the same key.</para>
        /// </remarks>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> that will track the entity.</param>
        /// <param name="entity">The entity carrying the new values and the key of the row to update.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null.</exception>
        public static void ExTranUpdate<T>(this DbContext DB, T entity) where T : class
        {
            if (entity == null)
            {
                throw new ArgumentNullException("entity");
            }

            foreach (PropertyInfo pi in typeof(T).GetProperties())
            {
                bool isUpdateTime = pi.CanWrite
                    && string.Equals(pi.Name, "UpdateTime", StringComparison.OrdinalIgnoreCase)
                    && (pi.PropertyType == typeof(DateTime) || pi.PropertyType == typeof(DateTime?));
                if (!isUpdateTime)
                {
                    continue;
                }

                object current = pi.GetValue(entity, null);
                if (current == null || (DateTime)current == default(DateTime))
                {
                    pi.SetValue(entity, DateTime.Now, null);
                }
            }

            DB.Set<T>().Attach(entity);
            DB.Entry<T>(entity).State = EntityState.Modified;
        }

        #endregion 

        #region 删除

        /// <summary>
        /// Deletes the row identified by <paramref name="entity"/>'s key and saves immediately.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to delete through.</param>
        /// <param name="entity">Entity whose key identifies the row; it is attached to the context first.</param>
        /// <returns>The value returned by <c>DbContext.SaveChanges()</c>.</returns>
        public static int ExDelete<T>(this DbContext DB, T entity) where T : class
        {
            var set = DB.Set<T>();
            set.Attach(entity);

            var entry = DB.Entry<T>(entity);
            entry.State = EntityState.Deleted;

            return DB.SaveChanges();
        }

        /// <summary>
        /// Loads every row matching <paramref name="where"/>, marks each one deleted, and saves immediately.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to delete through.</param>
        /// <param name="where">Predicate selecting the rows to delete.</param>
        /// <returns>The value returned by <c>DbContext.SaveChanges()</c>.</returns>
        public static int ExDelete<T>(this DbContext DB, Expression<Func<T, bool>> where) where T : class
        {
            var matches = DB.Set<T>().Where(where).ToList();
            foreach (var match in matches)
            {
                DB.Entry<T>(match).State = EntityState.Deleted;
            }
            return DB.SaveChanges();
        }

        /// <summary>
        /// Loads every row matching <paramref name="where"/> and marks each one deleted without saving;
        /// a later <c>SaveChanges</c> call performs the deletes.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to delete through.</param>
        /// <param name="where">Predicate selecting the rows to delete.</param>
        public static void ExTranDelete<T>(this DbContext DB, Expression<Func<T, bool>> where) where T : class
        {
            // Materialise first so the query reader is finished before entry states are changed.
            foreach (T doomed in DB.Set<T>().Where(where).ToList())
            {
                DB.Entry<T>(doomed).State = EntityState.Deleted;
            }
        }
        #endregion 删除

        #region 查询 分页


        ///// <summary>
        ///// 查询
        ///// </summary>
        ///// <typeparam name="T">类型</typeparam>
        ///// <param name="DB">DbContext上下文对象</param>
        ///// <param name="where">是否启用逻辑删除判断,默认禁用</param>
        ///// <returns></returns>
        //public static IEnumerable<T> ExSelect<T>(this DbContext DB, List<Guid> where) where T : class
        //{
        //    var entity = DB.Set<T>().AsQueryable();

        //    //ParameterExpression parameter = Expression.Parameter(typeof(T), "ex");
        //    //MemberExpression member1 = Expression.PropertyOrField(parameter, "Guid");
        //    //ConstantExpression constant1 = Expression.Constant(where, typeof(Guid));

        //    //var query = Expression.eq(member1, constant1);
        //    //var whereIsDelete = Expression.Lambda<Func<T, Boolean>>(query, parameter);
        //    //entity = entity.Where(whereIsDelete);


        //    ParameterExpression parameter = Expression.Parameter(typeof(T), "ex");
        //    MemberExpression member = Expression.PropertyOrField(parameter, "Guid");
        //    MethodInfo method = typeof(string).GetMethod("Contains", new[] { typeof(string) });
        //    ConstantExpression constant = Expression.Constant(where, typeof(string));

        //    var a = Expression.Lambda<Func<T, bool>>(Expression.Call(constant, method, member), parameter);

        //    entity = entity.Where(a);
        //    entity = entity.AsNoTracking();
        //    return entity;
        //}


        /// <summary>
        /// Builds a no-tracking query over <typeparamref name="T"/>, optionally filtered by <paramref name="where"/>.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to query.</param>
        /// <param name="where">Filter predicate; <c>null</c> (the default) returns all rows.</param>
        /// <returns>
        /// A deferred query (statically typed as <see cref="IEnumerable{T}"/>); it runs each time it is enumerated.
        /// </returns>
        public static IEnumerable<T> ExSelect<T>(this DbContext DB, Expression<Func<T, bool>> where = null) where T : class
        {
            IQueryable<T> query = DB.Set<T>();
            if (where != null)
            {
                query = query.Where(where);
            }
            return query.AsNoTracking();
        }

        ///// <summary>
        ///// 查询
        ///// </summary>
        ///// <typeparam name="T">类型</typeparam>
        ///// <param name="DB">DbContext上下文对象</param>
        ///// <param name="where">where查询条件表达式</param>
        /////// <param name="falg">是否启用逻辑删除判断,默认禁用</param>
        ///// <returns></returns>
        //public static IEnumerable<T> ExSelect<T>(this DbContext DB, Expression<Func<T, bool>> where = null, Expression<Func<T, dynamic>> scalar = null) where T : class
        //{
        //    var entity = DB.Set<T>().AsQueryable();

        //    //if (falg)
        //    //{
        //    //    //获取逻辑删除处理
        //    //    ParameterExpression parameter = Expression.Parameter(typeof(T), "ex");
        //    //    MemberExpression member1 = Expression.PropertyOrField(parameter, "IsDelete");
        //    //    ConstantExpression constant1 = Expression.Constant(0, typeof(Nullable<int>));
        //    //    var query = Expression.Equal(member1, constant1);
        //    //    var whereIsDelete = Expression.Lambda<Func<T, Boolean>>(query, parameter);
        //    //    entity = entity.Where(whereIsDelete);
        //    //}

        //    entity = where != null ? entity.Where(where) : entity;
        //    entity = entity.AsNoTracking();

        //    if (scalar != null)
        //    {
        //        var a = entity.Select(scalar); 
        //        return a;
        //    }
        //    return entity;
        //}

        /// <summary>
        /// Finds an entity by its (single, Guid-typed) primary key via <c>DbSet&lt;T&gt;.Find</c>.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to search.</param>
        /// <param name="guid">Primary-key value.</param>
        /// <returns>The entity, or <c>null</c> when no entity with that key exists.</returns>
        public static T ExFind<T>(this DbContext DB, Guid guid) where T : class
        {
            return DB.Set<T>().Find(guid);
        }

        /// <summary>
        /// Returns one page of <typeparamref name="T"/> rows whose <c>IsDelete</c> equals 0, optionally further
        /// filtered by <paramref name="where"/>.
        /// </summary>
        /// <remarks>
        /// This overload is equivalent to <c>ExPage&lt;T&gt;(where, page)</c>. The <paramref name="orderBy"/>
        /// argument is NOT used: ordering always comes from <c>page.sort</c> / <c>page.order</c>.
        /// NOTE(review): either honour <paramref name="orderBy"/> or mark this overload [Obsolete]; it is kept
        /// for compatibility with existing callers.
        /// </remarks>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <typeparam name="Tkey">Key type of <paramref name="orderBy"/> (unused).</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to query.</param>
        /// <param name="where">Extra filter; <c>null</c> means no extra filter.</param>
        /// <param name="orderBy">Ignored; see remarks.</param>
        /// <param name="page">Paging request/response object (sort, order, page number/size in; total out).</param>
        /// <returns>The requested page as a materialised, untracked list.</returns>
        public static IEnumerable<T> ExPage<T, Tkey>(this DbContext DB, Expression<Func<T, bool>> where, Expression<Func<T, Tkey>> orderBy, Model.Pages page) where T : class, new()
        {
            // Both ExPage overloads carried the same body; keep a single implementation.
            return DB.ExPage<T>(where, page);
        }

        /// <summary>
        /// Returns one page of <typeparamref name="T"/> rows whose <c>IsDelete</c> equals 0 (logical-delete
        /// filter), optionally further filtered by <paramref name="where"/>.
        /// </summary>
        /// <remarks>
        /// <para><paramref name="page"/> is both input and output. An empty <c>page.sort</c> is set to
        /// <c>"CreateTime"</c>, and an empty <c>page.order</c> is set to <c>"desc"</c>. <c>page.total</c>
        /// receives the number of matching rows before paging.</para>
        /// <para>Sorting is ascending only when <c>page.order</c> is "asc" (ignoring case and surrounding
        /// whitespace); any other value sorts descending. A <c>pageNumber</c> below 1 is treated as page 1.</para>
        /// </remarks>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> to query.</param>
        /// <param name="where">Extra filter; <c>null</c> means no extra filter.</param>
        /// <param name="page">Paging request/response; <c>pageNumber</c> is 1-based.</param>
        /// <returns>The requested page as a materialised, untracked list (empty when nothing matches).</returns>
        /// <exception cref="ArgumentNullException"><paramref name="page"/> is null.</exception>
        /// <exception cref="ArgumentException"><c>page.sort</c> is not a property or field of <typeparamref name="T"/>.</exception>
        public static IEnumerable<T> ExPage<T>(this DbContext DB, Expression<Func<T, bool>> where, Model.Pages page) where T : class, new()
        {
            if (page == null)
            {
                throw new ArgumentNullException("page");
            }

            if (string.IsNullOrWhiteSpace(page.sort))
            {
                page.sort = "CreateTime";
            }
            if (string.IsNullOrWhiteSpace(page.order))
            {
                page.order = "desc";
            }

            // Filter before sorting, applying each predicate once (the count used to re-apply `where`).
            IQueryable<T> entity = DB.Set<T>().Where(CommonLib.LambdaHelper<T>.Equal<int>("IsDelete", 0));
            if (where != null)
            {
                entity = entity.Where(where);
            }

            // Build x => x.<page.sort> before touching the database, so a bad column name fails fast.
            ParameterExpression row = Expression.Parameter(entity.ElementType);
            LambdaExpression keySelector = Expression.Lambda(Expression.PropertyOrField(row, page.sort), row);
            bool ascending = string.Equals(page.order.Trim(), "asc", StringComparison.OrdinalIgnoreCase);

            page.total = entity.Count();
            if (page.total <= 0)
            {
                return new List<T>();
            }

            // Always apply an ordering: LINQ to Entities rejects Skip on unsorted input, which previously
            // happened whenever page.order was neither "asc" nor "desc".
            entity = (IQueryable<T>)entity.Provider.CreateQuery(Expression.Call(
                typeof(Queryable),
                ascending ? "OrderBy" : "OrderByDescending",
                new Type[] { entity.ElementType, keySelector.Body.Type },
                entity.Expression,
                Expression.Quote(keySelector)));

            int skip = Math.Max(0, (page.pageNumber - 1) * page.pageSize);
            return entity.Skip(skip).Take(page.pageSize).AsNoTracking().ToList();
        }

        #endregion  查询 分页

        #region 排序

        /// <summary>
        /// Orders <paramref name="source"/> by a single key: ascending by default, descending when
        /// <paramref name="isDesc"/> is true.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <typeparam name="Tkey">Type of the sort key.</typeparam>
        /// <param name="source">Query to order.</param>
        /// <param name="orderby">Key selector, e.g. <c>x =&gt; x.CreateTime</c>.</param>
        /// <param name="isDesc"><c>true</c> for descending order; defaults to ascending.</param>
        /// <returns>The ordered query.</returns>
        public static IQueryable<T> ExOrderby<T, Tkey>(this IQueryable<T> source, Expression<Func<T, Tkey>> orderby, bool isDesc = false) where T : class
        {
            return isDesc
                ? source.OrderByDescending(orderby)
                : source.OrderBy(orderby);
        }

        /// <summary>
        /// Orders <paramref name="source"/> by several keys in sequence: the first field becomes
        /// <c>OrderBy</c>/<c>OrderByDescending</c>, and each later field becomes <c>ThenBy</c>/<c>ThenByDescending</c>.
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <param name="source">Query to order.</param>
        /// <param name="orderByExpression">
        /// Sort fields in priority order; each names a public property of <typeparamref name="T"/>
        /// (case-sensitive). <c>null</c> or empty returns <paramref name="source"/> unchanged.
        /// </param>
        /// <returns>The ordered query.</returns>
        /// <exception cref="ArgumentException">
        /// An entry is null, or names no public property of <typeparamref name="T"/>.
        /// </exception>
        public static IQueryable<T> ExOrderby<T>(this IQueryable<T> source, OrderModelField[] orderByExpression) where T : class
        {
            if (orderByExpression == null || orderByExpression.Length == 0)
            {
                return source;
            }

            var parameter = Expression.Parameter(typeof(T), "o");

            for (int i = 0; i < orderByExpression.Length; i++)
            {
                OrderModelField field = orderByExpression[i];
                if (field == null)
                {
                    throw new ArgumentException("Sort field at index " + i + " is null.", "orderByExpression");
                }

                // Report an unknown name directly instead of letting MakeMemberAccess fail with
                // an unhelpful ArgumentNullException for its "member" argument.
                PropertyInfo property = field.propertyName == null ? null : typeof(T).GetProperty(field.propertyName);
                if (property == null)
                {
                    throw new ArgumentException(
                        "Type " + typeof(T).Name + " has no public property '" + field.propertyName + "' to sort by.",
                        "orderByExpression");
                }

                LambdaExpression keySelector = Expression.Lambda(Expression.MakeMemberAccess(parameter, property), parameter);

                string methodName = i == 0
                    ? (field.isDesc ? "OrderByDescending" : "OrderBy")
                    : (field.isDesc ? "ThenByDescending" : "ThenBy");

                MethodCallExpression call = Expression.Call(
                    typeof(Queryable),
                    methodName,
                    new Type[] { typeof(T), property.PropertyType },
                    source.Expression,
                    Expression.Quote(keySelector));
                source = source.Provider.CreateQuery<T>(call);
            }
            return source;
        }

        #endregion 排序

        #region in 操作

        /// <summary>
        /// Filters <paramref name="source"/> to elements whose selected value is one of <paramref name="values"/>
        /// (an SQL-style <c>IN</c>).
        /// </summary>
        /// <typeparam name="T">Element type.</typeparam>
        /// <typeparam name="TKey">Type of the compared value.</typeparam>
        /// <param name="source">Query to filter.</param>
        /// <param name="valueSelector">Value selector, e.g. <c>c =&gt; c.UserId</c>.</param>
        /// <param name="values">Allowed values; enumerated exactly once.</param>
        /// <returns>
        /// The filtered query. When <paramref name="values"/> is empty, <paramref name="source"/> is returned
        /// UNFILTERED (not an empty result, as SQL <c>IN ()</c> would imply).
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="valueSelector"/> or <paramref name="values"/> is null.</exception>
        public static IQueryable<T> ExWhereIn<T, TKey>(this IQueryable<T> source, Expression<Func<T, TKey>> valueSelector, IEnumerable<TKey> values) where T : class
        {
            if (null == valueSelector) { throw new ArgumentNullException("valueSelector"); }
            if (null == values) { throw new ArgumentNullException("values"); }

            // Snapshot once: the previous code enumerated `values` twice (Any, then Select).
            List<TKey> valueList = values.ToList();
            if (valueList.Count == 0)
            {
                return source;
            }

            // Build x => valueList.Contains(selector(x)). LINQ to Entities translates Enumerable.Contains on a
            // local collection into a single IN (...) clause. The previous chain of OR-ed equality nodes grew
            // the expression tree by one nesting level per value.
            ParameterExpression p = valueSelector.Parameters.Single();
            MethodCallExpression contains = Expression.Call(
                typeof(Enumerable),
                "Contains",
                new Type[] { typeof(TKey) },
                Expression.Constant(valueList, typeof(IEnumerable<TKey>)),
                valueSelector.Body);
            return source.Where(Expression.Lambda<Func<T, bool>>(contains, p));
        }

        #endregion in操作

        #endregion EntityFramework

        #region sql形式

        /// <summary>
        /// Executes a raw SQL command against the context's database via <c>Database.ExecuteSqlCommand</c>.
        /// </summary>
        /// <param name="DB">The <see cref="DbContext"/> whose database connection is used.</param>
        /// <param name="sql">Command text; use placeholders such as <c>{0}</c> (or DbParameter objects) for values.</param>
        /// <param name="paras">Values for the command's parameters.</param>
        /// <returns>The result returned by the database for the command (typically the affected row count).</returns>
        public static int ExecuteSql(this DbContext DB, string sql, params object[] paras)
        {
            return DB.Database.ExecuteSqlCommand(sql, paras);
        }

        /// <summary>
        /// Runs a raw SQL query and maps each result row to <typeparamref name="T"/> via <c>Database.SqlQuery</c>.
        /// </summary>
        /// <typeparam name="T">Row type whose properties match the result columns.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> whose database connection is used.</param>
        /// <param name="sql">Query text; use placeholders such as <c>{0}</c> (or DbParameter objects) for values.</param>
        /// <param name="paras">Values for the query's parameters.</param>
        /// <returns>A deferred sequence: the SQL is executed when (and each time) it is enumerated.</returns>
        public static IEnumerable<T> ExSelectListBySql<T>(this DbContext DB, string sql, params object[] paras) where T : class
        {
            return DB.Database.SqlQuery<T>(sql, paras);
        }

        /// <summary>
        /// Runs a raw SQL query and returns its first row mapped to <typeparamref name="T"/>.
        /// </summary>
        /// <typeparam name="T">Row type whose properties match the result columns.</typeparam>
        /// <param name="DB">The <see cref="DbContext"/> whose database connection is used.</param>
        /// <param name="sql">Query text; use placeholders such as <c>{0}</c> (or DbParameter objects) for values.</param>
        /// <param name="paras">Values for the query's parameters.</param>
        /// <returns>The first row, or <c>null</c> when the query returns no rows.</returns>
        public static T ExSelectBySql<T>(this DbContext DB, string sql, params object[] paras) where T : class
        {
            var rows = DB.Database.SqlQuery<T>(sql, paras);
            return rows.FirstOrDefault();
        }

        /// <summary>
        /// Executes a batch of SQL commands in one database transaction. The batch is committed only if every
        /// command succeeds; otherwise it is rolled back and the original exception is rethrown.
        /// </summary>
        /// <param name="DB">The <see cref="DbContext"/> whose database connection is used.</param>
        /// <param name="listSql">Commands to execute, in order; each is executed exactly once.</param>
        /// <param name="listParas">
        /// Optional per-command parameters. When non-null and non-empty, <c>listParas[j]</c> is passed with
        /// <c>listSql[j]</c> (a null entry means "no parameters"), and it must have at least as many entries
        /// as <paramref name="listSql"/>. When null or empty, every command runs without parameters.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="listSql"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="listParas"/> is non-empty but shorter than <paramref name="listSql"/>.
        /// </exception>
        public static void ExecuteTransaction(this DbContext DB, List<string> listSql, List<object[]> listParas)
        {
            if (listSql == null)
            {
                throw new ArgumentNullException("listSql");
            }

            bool hasParameters = listParas != null && listParas.Count > 0;
            if (hasParameters && listParas.Count < listSql.Count)
            {
                throw new ArgumentException("listParas must have at least as many entries as listSql.", "listParas");
            }

            using (var tran = DB.Database.BeginTransaction())
            {
                try
                {
                    // One pass only. Previously, when parameters were supplied, a second loop re-executed
                    // every statement without its parameters, so each statement ran twice.
                    for (int j = 0; j < listSql.Count; j++)
                    {
                        object[] parameters = hasParameters ? (listParas[j] ?? new object[0]) : new object[0];
                        DB.Database.ExecuteSqlCommand(listSql[j], parameters);
                    }
                    tran.Commit();
                }
                catch
                {
                    tran.Rollback();
                    throw; // rethrow without resetting the stack trace (`throw ex;` did)
                }
            }
        }

        #endregion sql形式

        #region 排序字段

        /// <summary>
        /// One key of a multi-column sort, as consumed by the <c>ExOrderby</c> overload that takes an array.
        /// </summary>
        public class OrderModelField
        {
            private bool _isDesc;
            private string _propertyName;

            /// <summary>
            /// <c>true</c> to sort this key descending; <c>false</c> (the default) sorts it ascending.
            /// </summary>
            public bool isDesc
            {
                get { return _isDesc; }
                set { _isDesc = value; }
            }

            /// <summary>
            /// Name of the public property of the element type to sort by (looked up case-sensitively).
            /// </summary>
            public string propertyName
            {
                get { return _propertyName; }
                set { _propertyName = value; }
            }
        }

        #endregion 排序字段
    }
}
